{ "cells": [ { "cell_type": "markdown", "source": [ "# Complex Sensitive Features\n", "\n", "The [Simple Pipeline notebook](./Simple Pipeline.ipynb) covered the basic use of the `fairret` library. Now, we'll dive a bit deeper into the sensitive features tensor, including a mix of discrete and continuous features.\n", "\n", "We'll skim over the data loading and model definition this time." ], "metadata": { "collapsed": false }, "id": "4177da6e00f826" }, { "cell_type": "markdown", "source": [ "## Data prep" ], "metadata": { "collapsed": false }, "id": "2336f02f615afcb1" }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "from folktables import ACSDataSource, ACSIncome, generate_categories\n", "\n", "data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')\n", "data = data_source.get_data(states=[\"CA\"], download=True)\n", "definition_df = data_source.get_definitions(download=True)\n", "categories = generate_categories(features=ACSIncome.features, definition_df=definition_df)\n", "df_feat, df_labels, _ = ACSIncome.df_to_pandas(data, categories=categories, dummies=True)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:06:03.074297800Z", "start_time": "2024-07-23T15:05:47.737499700Z" } }, "id": "5aad391248c4779d" }, { "cell_type": "markdown", "source": [ "We'll consider three types of sensitive features: SEX (binary), RAC1P (categorical), and AGEP (continuous).\n", "\n", "Their values look like this:" ], "metadata": { "collapsed": false }, "id": "547a8ba4a90a4e1b" }, { "cell_type": "code", "execution_count": 2, "outputs": [ { "data": { "text/plain": " AGEP SEX_Female SEX_Male RAC1P_Alaska Native alone \\\n0 30 False True False \n1 21 False True False \n2 65 False True False \n3 33 False True False \n4 18 True False False \n\n RAC1P_American Indian alone \\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n RAC1P_American Indian and Alaska Native tribes specified; or American Indian or 
Alaska Native, not specified and no other races \\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n RAC1P_Asian alone RAC1P_Black or African American alone \\\n0 False False \n1 False False \n2 False False \n3 False False \n4 False False \n\n RAC1P_Native Hawaiian and Other Pacific Islander alone \\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n RAC1P_Some Other Race alone RAC1P_Two or More Races RAC1P_White alone \n0 True False False \n1 False False True \n2 False False True \n3 False False True \n4 False False True ", "text/html": "
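<div>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\"><th></th><th>AGEP</th><th>SEX_Female</th><th>SEX_Male</th><th>RAC1P_Alaska Native alone</th><th>RAC1P_American Indian alone</th><th>RAC1P_American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races</th><th>RAC1P_Asian alone</th><th>RAC1P_Black or African American alone</th><th>RAC1P_Native Hawaiian and Other Pacific Islander alone</th><th>RAC1P_Some Other Race alone</th><th>RAC1P_Two or More Races</th><th>RAC1P_White alone</th></tr>\n </thead>\n <tbody>\n <tr><th>0</th><td>30</td><td>False</td><td>True</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>True</td><td>False</td><td>False</td></tr>\n <tr><th>1</th><td>21</td><td>False</td><td>True</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>True</td></tr>\n <tr><th>2</th><td>65</td><td>False</td><td>True</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>True</td></tr>\n <tr><th>3</th><td>33</td><td>False</td><td>True</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>True</td></tr>\n <tr><th>4</th><td>18</td><td>True</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>False</td><td>True</td></tr>\n </tbody>\n</table>\n</div>\n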
" }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sens_cols = [col for col in df_feat.columns if (col.split('_')[0] in ['SEX', 'RAC1P', 'AGEP'])]\n", "df_feat[sens_cols].head()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:06:03.140372500Z", "start_time": "2024-07-23T15:06:03.074166700Z" } }, "id": "fbfcf03d56d7757e" }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "feat = df_feat.drop(columns=sens_cols).to_numpy(dtype=\"float\")\n", "sens = df_feat[sens_cols].to_numpy(dtype=\"float\")\n", "label = df_labels.to_numpy(dtype=\"float\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:06:09.911034200Z", "start_time": "2024-07-23T15:06:03.132672800Z" } }, "id": "8140169aecea986b" }, { "cell_type": "markdown", "source": [ "Like in [Simple Pipeline.ipynb](./Simple Pipeline.ipynb), we just treat sensitive features in the same way 'normal' features are always treated in PyTorch: as (N x D) tensors, where N is the number of samples and D the dimensionality. The only difference is that we now have a **mix of continuous and categorical sensitive features**. All other steps remain the same!" 
], "metadata": { "collapsed": false }, "id": "62be1fd216fac319" }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of the 'normal' features tensor: torch.Size([195665, 804])\n", "Shape of the sensitive features tensor: torch.Size([195665, 12])\n", "Shape of the labels tensor: torch.Size([195665, 1])\n" ] } ], "source": [ "import torch\n", "torch.manual_seed(0)\n", "feat, sens, label = torch.tensor(feat).float(), torch.tensor(sens).float(), torch.tensor(label).float()\n", "print(f\"Shape of the 'normal' features tensor: {feat.shape}\")\n", "print(f\"Shape of the sensitive features tensor: {sens.shape}\")\n", "print(f\"Shape of the labels tensor: {label.shape}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:06:13.086314100Z", "start_time": "2024-07-23T15:06:09.906335600Z" } }, "id": "88c3bf10d3c396a9" }, { "cell_type": "markdown", "source": [ "## A naive PyTorch pipeline" ], "metadata": { "collapsed": false }, "id": "a6954d8968e4fc7f" }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.5894676046445966\n", "Epoch: 1, loss: 0.45432546191538375\n", "Epoch: 2, loss: 0.42644635944937664\n", "Epoch: 3, loss: 0.41861366961772245\n", "Epoch: 4, loss: 0.4148668508666257\n", "Epoch: 5, loss: 0.41254297791359323\n", "Epoch: 6, loss: 0.4108568604569882\n", "Epoch: 7, loss: 0.409562521148473\n", "Epoch: 8, loss: 0.4085248096380383\n", "Epoch: 9, loss: 0.40770633627350134\n", "Epoch: 10, loss: 0.40696429484523833\n", "Epoch: 11, loss: 0.4062801259569824\n", "Epoch: 12, loss: 0.405580735920618\n", "Epoch: 13, loss: 0.4048566308338195\n", "Epoch: 14, loss: 0.4040601346641779\n", "Epoch: 15, loss: 0.40290021896362305\n", "Epoch: 16, loss: 0.4015411910756181\n", "Epoch: 17, loss: 0.4003351483649264\n", "Epoch: 18, loss: 0.3992322162569811\n", "Epoch: 19, loss: 0.3981786568959554\n", "Epoch: 20, 
loss: 0.39722149074077606\n", "Epoch: 21, loss: 0.39636922037849825\n", "Epoch: 22, loss: 0.3955736063265552\n", "Epoch: 23, loss: 0.3948506594557936\n", "Epoch: 24, loss: 0.39417379527973634\n" ] } ], "source": [ "import numpy as np\n", "from torch.utils.data import TensorDataset, DataLoader\n", "\n", "h_layer_dim = 16\n", "lr = 1e-3\n", "batch_size = 1024\n", "nb_epochs = 25\n", "\n", "model = torch.nn.Sequential(\n", " torch.nn.Linear(feat.shape[1], h_layer_dim),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(h_layer_dim, 1)\n", ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", "dataset = TensorDataset(feat, sens, label)\n", "dataloader = DataLoader(dataset, batch_size=batch_size)\n", "\n", "for epoch in range(nb_epochs):\n", " losses = []\n", " for batch_feat, batch_sens, batch_label in dataloader:\n", " optimizer.zero_grad()\n", " \n", " logit = model(batch_feat)\n", " loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", " loss.backward()\n", " \n", " optimizer.step()\n", " losses.append(loss.item())\n", " print(f\"Epoch: {epoch}, loss: {np.mean(losses)}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:07:52.813598400Z", "start_time": "2024-07-23T15:06:13.086314100Z" } }, "id": "c606ef4196a662c9" }, { "cell_type": "markdown", "source": [ "## Multi-dimensional bias analysis\n", "\n", "Can we detect any statistical disparities (biases) in the naive model, with respect to our mix of sensitive attributes?\n", "\n", "Instead of considering the pairwise gaps between the statistics of groups, we can approach the problem more generally by setting a target value for the statistic. Luckily, all LinearFractionalStatistics in `fairret` have a principled candidate for such a value: the overall statistic, which considers the entire dataset as a single 'group'. The TruePositiveRate is such a LinearFractionalStatistic." 
], "metadata": { "collapsed": false }, "id": "85ea660c0b0f79d6" }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The TruePositiveRate for group AGEP is 0.6951881051063538\n", "The TruePositiveRate for group SEX_Female is 0.6863165497779846\n", "The TruePositiveRate for group SEX_Male is 0.7048059105873108\n", "The TruePositiveRate for group RAC1P_Alaska Native alone is 0.7878805994987488\n", "The TruePositiveRate for group RAC1P_American Indian alone is 0.6357859373092651\n", "The TruePositiveRate for group RAC1P_American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races is 0.6129217147827148\n", "The TruePositiveRate for group RAC1P_Asian alone is 0.7354669570922852\n", "The TruePositiveRate for group RAC1P_Black or African American alone is 0.6464908123016357\n", "The TruePositiveRate for group RAC1P_Native Hawaiian and Other Pacific Islander alone is 0.5953413248062134\n", "The TruePositiveRate for group RAC1P_Some Other Race alone is 0.507019579410553\n", "The TruePositiveRate for group RAC1P_Two or More Races is 0.6937863230705261\n", "The TruePositiveRate for group RAC1P_White alone is 0.705863893032074\n", "The overall TruePositiveRate is tensor([0.6974], grad_fn=)\n", "The maximal absolute difference is 0.1903666853904724\n" ] } ], "source": [ "from fairret.statistic import TruePositiveRate\n", "\n", "statistic = TruePositiveRate()\n", "\n", "naive_pred = torch.sigmoid(model(feat))\n", "naive_stat_per_group = statistic(naive_pred, sens, label)\n", "naive_overall_stat = statistic.overall_statistic(naive_pred, label).squeeze().item()\n", "naive_absolute_diff = torch.abs(naive_stat_per_group - naive_overall_stat)\n", "\n", "for i, col in enumerate(sens_cols):\n", " print(f\"The {statistic.__class__.__name__} for group {col} is {naive_stat_per_group[i]}\")\n", "print(f\"The overall {statistic.__class__.__name__} is 
{statistic.overall_statistic(naive_pred, label)}\")\n", "print(f\"The maximal absolute difference is {torch.max(naive_absolute_diff)}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:07:52.957765400Z", "start_time": "2024-07-23T15:07:52.816015200Z" } }, "id": "d6f409252d0a016a" }, { "cell_type": "markdown", "source": [ "As we can see, the biggest outlier in TruePositiveRate is the \"Some Other Race alone\" group (0.507 against an overall rate of 0.697). However, our fairrets should try to reduce all disparities.\n", "\n", "Note: the TruePositiveRate computed for age as a 'group' may seem a bit strange, but it is quite interpretable. It is the rate at which actual positives are predicted as positive, with each sample weighted by its age. Hence, if age does not (linearly) influence whether an actual positive is predicted as positive, the statistic will be close to the overall statistic." ], "metadata": { "collapsed": false }, "id": "d03793575d22f45e" }, { "cell_type": "markdown", "source": [ "## Bias mitigation in fairret" ], "metadata": { "collapsed": false }, "id": "f063994330c1724f" }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.8587691346183419\n", "Epoch: 1, loss: 0.7747252158199748\n", "Epoch: 2, loss: 0.7440257190416256\n", "Epoch: 3, loss: 0.7302185222506523\n", "Epoch: 4, loss: 0.723417630729576\n", "Epoch: 5, loss: 0.7181567416215936\n", "Epoch: 6, loss: 0.7141167384882768\n", "Epoch: 7, loss: 0.711620382964611\n", "Epoch: 8, loss: 0.7090622059380015\n", "Epoch: 9, loss: 0.7072434027989706\n", "Epoch: 10, loss: 0.7041833276549975\n", "Epoch: 11, loss: 0.7031577831755081\n", "Epoch: 12, loss: 0.70156757440418\n", "Epoch: 13, loss: 0.7001269410053889\n", "Epoch: 14, loss: 0.6993842964681486\n", "Epoch: 15, loss: 0.6980737828028699\n", "Epoch: 16, loss: 0.6973260877033075\n", "Epoch: 17, loss: 0.6958203539252281\n", "Epoch: 18, loss: 0.6948677546655139\n", "Epoch: 19, loss: 0.6939027289239069\n", 
"Epoch: 20, loss: 0.6936575947329402\n", "Epoch: 21, loss: 0.691608909672747\n", "Epoch: 22, loss: 0.690773538624247\n", "Epoch: 23, loss: 0.6903554652817547\n", "Epoch: 24, loss: 0.6890416747579972\n" ] } ], "source": [ "from fairret.loss import NormLoss\n", "\n", "norm_loss = NormLoss(statistic)\n", "fairness_strength = 0.1\n", "model = torch.nn.Sequential(\n", " torch.nn.Linear(feat.shape[1], h_layer_dim),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(h_layer_dim, 1)\n", ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", "\n", "for epoch in range(nb_epochs):\n", " losses = []\n", " for batch_feat, batch_sens, batch_label in dataloader:\n", " optimizer.zero_grad()\n", " \n", " logit = model(batch_feat)\n", " loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", " loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)\n", " loss.backward()\n", " \n", " optimizer.step()\n", " losses.append(loss.item())\n", " print(f\"Epoch: {epoch}, loss: {np.mean(losses)}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:09:57.518814300Z", "start_time": "2024-07-23T15:07:52.958985600Z" } }, "id": "4268b8b24114805d" }, { "cell_type": "markdown", "source": [ "Let's check the true positive rate per group again..." 
], "metadata": { "collapsed": false }, "id": "1f544ff6b8925f6a" }, { "cell_type": "code", "execution_count": 8, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The TruePositiveRate for group AGEP is 0.6903698444366455\n", "The TruePositiveRate for group SEX_Female is 0.6866227388381958\n", "The TruePositiveRate for group SEX_Male is 0.6959303021430969\n", "The TruePositiveRate for group RAC1P_Alaska Native alone is 0.7395374774932861\n", "The TruePositiveRate for group RAC1P_American Indian alone is 0.6677444577217102\n", "The TruePositiveRate for group RAC1P_American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races is 0.6708988547325134\n", "The TruePositiveRate for group RAC1P_Asian alone is 0.6986129283905029\n", "The TruePositiveRate for group RAC1P_Black or African American alone is 0.674475908279419\n", "The TruePositiveRate for group RAC1P_Native Hawaiian and Other Pacific Islander alone is 0.665993332862854\n", "The TruePositiveRate for group RAC1P_Some Other Race alone is 0.5838647484779358\n", "The TruePositiveRate for group RAC1P_Two or More Races is 0.6922767162322998\n", "The TruePositiveRate for group RAC1P_White alone is 0.7005200982093811\n", "The overall TruePositiveRate is tensor([0.6922], grad_fn=)\n", "The maximal absolute difference is 0.1083303689956665\n" ] } ], "source": [ "pred = torch.sigmoid(model(feat))\n", "stat_per_group = statistic(pred, sens, label)\n", "overall_stat = statistic.overall_statistic(pred, label).squeeze().item()\n", "absolute_diff = torch.abs(stat_per_group - overall_stat)\n", "\n", "for i, col in enumerate(sens_cols):\n", " print(f\"The {statistic.__class__.__name__} for group {col} is {stat_per_group[i]}\")\n", "print(f\"The overall {statistic.__class__.__name__} is {statistic.overall_statistic(pred, label)}\")\n", "print(f\"The maximal absolute difference is {torch.max(absolute_diff)}\")" ], "metadata": { "collapsed": false, 
"ExecuteTime": { "end_time": "2024-07-23T15:09:57.624592300Z", "start_time": "2024-07-23T15:09:57.523224800Z" } }, "id": "ff9f8d0fe5247a6a" }, { "cell_type": "markdown", "source": [ "With a small change, the maximal absolute difference between the statistics was reduced from 19% to 11%!\n", "\n", "In fact, all disparities were reduced:" ], "metadata": { "collapsed": false }, "id": "7217872a34eaf3bc" }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "data": { "text/plain": "" }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": "
", "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "\n", "short_sens_cols = [col[:30] for col in sens_cols]\n", "df = pd.DataFrame({\n", " 'kind': ['naive'] * len(sens_cols) + ['fairret'] * len(sens_cols),\n", " 'true_positive_rate': np.concatenate([naive_stat_per_group.detach(), stat_per_group.detach()]),\n", " 'sensitive_feature': np.concatenate([short_sens_cols, short_sens_cols])\n", "})\n", "sns.barplot(data=df, x='sensitive_feature', y='true_positive_rate', hue='kind')\n", "plt.axhline(y=overall_stat, color='black', linestyle='--')\n", "plt.gca().tick_params(axis='x', rotation=90)\n", "plt.legend(ncols=2)\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:09:58.038205600Z", "start_time": "2024-07-23T15:09:57.627716600Z" } }, "id": "b1815b5326966d2b" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }